[Bugfix]: Fix atomicadd auto vectorize identify var error #883

Conversation
Walkthrough: Refactors AtomicAdd vectorization into a planner-driven pipeline (planner + plan result + rewriter).
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Caller
    participant Lower as AtomicAddNode::Lower
    participant SIMT as SIMT Builder/Fuser
    participant Collector as AtomicLoopNestCollector
    participant Layout as Layout Infer
    participant Planner as AtomicAddVectorizePlanner
    participant Rewriter as AtomicAddVectorizeRewriter
    participant IR as Resulting IR
    Caller->>Lower: Lower(...)
    Lower->>SIMT: Build & fuse SIMT loop
    SIMT-->>Lower: Fused For
    Lower->>Collector: Collect loop nest & buffer indices
    Collector-->>Lower: Loop metadata
    Lower->>Layout: Compute layout / predicate (InferLayout)
    Layout-->>Lower: Layout + optional guard
    Lower->>Planner: Plan(fused_For, compute_capability)
    Note right of Planner: Analyze AtomicAdd calls & dtypes → vector_size, dynamic, condition
    Planner-->>Rewriter: PlanResult
    Rewriter->>IR: Rewrite For (vectorized / guarded)
    IR-->>Lower: Lowered vectorized loop
    Lower-->>Caller: Final lowered IR
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks: ❌ 1 warning, ✅ 2 passed.
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Summary of Changes: Hello @yyttt6, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request resolves a bug in the AtomicAdd auto-vectorization variable identification.
Code Review

This pull request fixes a bug in identifying variables for atomicadd auto-vectorization by introducing a more robust ParseIndex function. The changes are a definite improvement over the previous, more brittle implementation. I've identified a potential issue in how multiple AtomicAdd calls within a loop are handled, which could lead to incorrect behavior. My review includes a suggestion to make this logic more robust. Additionally, it's good to see that a previously failing test case has been re-enabled as part of this fix.
Actionable comments posted: 0
🧹 Nitpick comments (5)
testing/python/language/test_tilelang_language_atomic_add.py (1)
Lines 375-376: Enable test: good; consider making it hardware-agnostic (float16) to avoid the cc >= 90 dependency. AtomicAddx4 for float32 is only selected when compute capability >= 90. On CI GPUs < 90 (e.g., A100, cc = 80), this path may not vectorize and could cause flakiness for the tile-atomic path. Two options:
- Portable: call with float16 so vectorization is available broadly.
- Alternatively, gate/skip on device capability.
Apply this minimal change for portability:
```diff
-def test_tile_atomic_add():
-    run_tile_atomic_add(8, 128, 128, 32, 32)
+def test_tile_atomic_add():
+    run_tile_atomic_add(8, 128, 128, 32, 32, dtype="float16")
```

Also, consider removing or gating the debug prints in run_tile_atomic_add to keep test output clean (prints at lines 58, 72, 73).
src/transform/atomicadd_vectorize.cc (4)
Lines 322-347: ParseIndex is too strict; accept const-expr strides and avoid false negatives. Requiring exactly one MulNode with a Var and an IntImm will miss common canonical forms:
- Stride may be a foldable const expr or come via casts (not a bare IntImm).
- Extra harmless multiplies like x*1 can appear pre-simplification.
- You only need a unique var*const match; other non-relevant muls shouldn’t invalidate the parse.
Refine by simplifying first, using as_const_int, and relaxing the check to “exactly one legal var*const mul” regardless of other muls:
```diff
-  auto ParseIndex = [](const PrimExpr &idx, PrimExpr &var_out,
-                       int &stride_out) -> bool {
+  auto ParseIndex = [](const PrimExpr &idx, PrimExpr &var_out,
+                       int &stride_out) -> bool {
     int mul_count = 0, legal_mul_count = 0;
     stride_out = -1;
     var_out = PrimExpr();
-    PostOrderVisit(idx, [&](const ObjectRef &obj) {
+    // Simplify to eliminate x*1 and fold-able constants.
+    arith::Analyzer az;
+    PrimExpr sidx = az.Simplify(idx);
+    PostOrderVisit(sidx, [&](const ObjectRef &obj) {
       if (const MulNode *mul = obj.as<MulNode>()) {
         mul_count++;
-        const VarNode *var = nullptr;
-        const IntImmNode *imm = nullptr;
-        if ((var = mul->a.as<VarNode>()) && (imm = mul->b.as<IntImmNode>())) {
-          var_out = mul->a;
-          stride_out = imm->value;
-          legal_mul_count++;
-        } else if ((var = mul->b.as<VarNode>()) &&
-                   (imm = mul->a.as<IntImmNode>())) {
-          var_out = mul->b;
-          stride_out = imm->value;
-          legal_mul_count++;
-        }
+        const VarNode *var = nullptr;
+        const int64_t *c = nullptr;
+        if ((var = mul->a.as<VarNode>()) && (c = as_const_int(mul->b))) {
+          var_out = mul->a;
+          stride_out = static_cast<int>(*c);
+          legal_mul_count++;
+        } else if ((var = mul->b.as<VarNode>()) && (c = as_const_int(mul->a))) {
+          var_out = mul->b;
+          stride_out = static_cast<int>(*c);
+          legal_mul_count++;
+        }
       }
     });
-    if (mul_count == 1 && legal_mul_count == 1)
-      return true;
-    return false;
+    return legal_mul_count == 1;
   };
```

Note: this uses as_const_int and simplification. If not already available, include tvm/arith/analyzer.h (already included).
Lines 362-368: Accumulate vectorize_size_max across multiple AtomicAdd sites. If the loop body contains multiple AtomicAdd calls, you currently overwrite vectorize_size_max. Prefer taking the max to avoid under-vectorizing later calls.
```diff
-    DataType dtype = bufload->dtype;
-    vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
+    DataType dtype = bufload->dtype;
+    vectorize_size_max = std::max(
+        vectorize_size_max, GetVectorizeSizeMax(compute_capability, dtype));
```

You’ll need:
- Add at top: #include &lt;algorithm&gt; (needed for std::max).
Lines 362-368: Guard against mis-identifying non-block vars as bx/by. ParseIndex will happily return any var*const match (e.g., a loop var like i1). Before accepting, assert the extracted vars are actual block indices (thread/block bindings) for safety; otherwise bail out. For example:
- Verify var_out.as&lt;VarNode&gt;() is bound in thread_binding as blockIdx.{x,y} (or matches expected bx/by symbols in this pass’ context).
- If that metadata isn’t available here, at least ensure both extracted vars differ and are not the loop var inside inner_for_.
This avoids rewriting with incorrect axes on more complex index expressions.
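A minimal sketch of such a bail-out check, assuming the pass can see the innermost loop node (the helper name and parameters are illustrative, not the actual pass API):

```cpp
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>

using namespace tvm;
using namespace tvm::tir;

// Hedged sketch: accept a ParseIndex result only when the two extracted
// exprs are distinct plain Vars and neither is the inner loop variable.
bool IsSafeAxisPair(const PrimExpr &bx_expr, const PrimExpr &by_expr,
                    const ForNode *inner_for) {
  const VarNode *bx = bx_expr.as<VarNode>();
  const VarNode *by = by_expr.as<VarNode>();
  if (bx == nullptr || by == nullptr)
    return false;  // not plain variables
  if (bx == by)
    return false;  // the same var cannot serve as both bx and by
  if (inner_for != nullptr && (bx == inner_for->loop_var.get() ||
                               by == inner_for->loop_var.get()))
    return false;  // a loop var was mistaken for a block index
  return true;
}
```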
Lines 380-385: Extra sanity checks before rewriting. Before constructing the rewriter:
- Ensure bx_var and by_var are Vars: if (!bx_var.as&lt;VarNode&gt;() || !by_var.as&lt;VarNode&gt;()) return for_node;
- Optionally ensure stride_x > 0 && stride_y > 0.
This prevents emitting malformed truncdiv/truncmod expressions.
📜 Review details: CodeRabbit UI, profile CHILL, plan Pro.
📒 Files selected for processing (2): src/transform/atomicadd_vectorize.cc (2 hunks), testing/python/language/test_tilelang_language_atomic_add.py (1 hunk)
Actionable comments posted: 0
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_atomic_add.py (1)
Lines 375-377: Remove noisy debug prints before re-enabling this test. Reactivating test_tile_atomic_add now exercises run_tile_atomic_add, which still contains print(kernel.get_kernel_source()) plus dumps of both 128×128 tensors. That’s hundreds of thousands of characters on every run and will swamp CI logs without adding assertion value. Please drop or gate those prints before merging.
📜 Review details: CodeRabbit UI, profile CHILL, plan Pro.
📒 Files selected for processing (1): testing/python/language/test_tilelang_language_atomic_add.py (1 hunk)
src/transform/atomicadd_vectorize.cc (outdated)

```cpp
For VectorizeAtomicAdd(const For &for_node, const Var &thread_var,
                       const Range &thread_bounds, int compute_capability) {

  auto ParseIndex = [](const PrimExpr &idx, PrimExpr &var_out,
```
I think we should add some comments for this function.
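For example, a hedged sketch of the kind of comment that could document ParseIndex (illustrative wording, not the text that landed in the PR):

```cpp
// Sketch of a possible doc comment for ParseIndex (illustrative only):
//
// ParseIndex: try to recognize, inside an index expression, exactly one
// product of the form `var * stride` where `stride` is a compile-time
// integer constant. On success the matched variable is written to
// `var_out`, its stride to `stride_out`, and true is returned. When no
// unique var * const product exists, false is returned and the caller
// should skip vectorizing this AtomicAdd site.
```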
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/op/atomic_add.cc (1)
Lines 35-46: Code duplication: GetArchInt is duplicated from src/target/utils.cc (lines 17-25) with a different implementation. The local version includes a fallback to return 0 for non-sm_ architectures, while the version in src/target/utils.cc uses ICHECK to enforce the sm_ prefix. Consider one of the following approaches:
- Preferred: Import and use the existing GetArchInt from src/target/utils.cc if the stricter validation is acceptable, or
- Update the version in src/target/utils.cc to include the fallback behavior and use it consistently across the codebase.

Apply this diff to use the existing function:
```diff
-static int GetArchInt(Target target) {
-  int arch_int = 0;
-  auto s = target->GetAttr<String>("arch");
-  ICHECK(s.defined());
-  std::string arch = s.value();
-  if (arch.rfind("sm_", 0) == 0) {
-    arch_int = std::stoi(arch.substr(3));
-  } else {
-    arch_int = 0;
-  }
-  return arch_int;
-}
```

And update the include at the top of the file if not already present:
#include "../target/utils.h"
🧹 Nitpick comments (3)
src/op/atomic_add.cc (1)
Line 372: Consider removing or adjusting the log level. The LOG(INFO) statement prints the vectorized loop IR to the console. This may be useful during development but could be noisy in production. Consider one of the following:
- Remove the log statement if it was added for debugging purposes only.
- Change to VLOG(1) or a higher verbosity level to reduce noise in production logs.
- If this is intentional diagnostic output, add a comment explaining why it's logged at INFO level.
Apply this diff to change to verbose logging:
```diff
- LOG(INFO) << vectorized_thread_loop;
+ VLOG(1) << "Vectorized thread loop: " << vectorized_thread_loop;
```

src/transform/atomicadd_vectorize.cc (2)
Lines 33-55: Consider adding documentation for BufferIndiceSimplify. The BufferIndiceSimplify class lacks documentation. Adding a brief comment explaining its purpose would improve maintainability. Apply this diff to add documentation:
```diff
+/// \brief Simplifies buffer load and store indices using an analyzer.
+///
+/// This mutator visits BufferLoad and BufferStore nodes and simplifies
+/// their indices by applying the analyzer's Simplify method to each index.
 class BufferIndiceSimplify : public StmtExprMutator {
```
Lines 174-231: Consider adding documentation for the run() method. The run() method implements complex loop transformation logic but lacks documentation explaining the transformation steps and the role of loop_layout and analyzer. Apply this diff to add documentation:
```diff
+  /// \brief Transform and vectorize the for loop using the provided layout.
+  ///
+  /// \param for_node The original For loop to transform
+  /// \param loop_layout Fragment describing the loop layout transformation
+  /// \param analyzer Analyzer for simplifying indices and binding loop variables
+  /// \return Transformed and vectorized For loop
   For run(For for_node, const Fragment &loop_layout, arith::Analyzer *analyzer) {
```
📜 Review details: CodeRabbit UI, profile CHILL, plan Pro.
📒 Files selected for processing (3): src/op/atomic_add.cc (1 hunk), src/transform/atomicadd_vectorize.cc (6 hunks), src/transform/atomicadd_vectorize.h (1 hunk)
🔇 Additional comments (5)
src/op/atomic_add.cc (1)
Lines 364-366: Call site correctly updated. The call to VectorizeAtomicAdd now includes the required analyzer and loop_layout parameters, matching the updated API signature.

src/transform/atomicadd_vectorize.cc (3)
Lines 33-55: BufferIndiceSimplify implementation looks correct. The mutator correctly simplifies indices for both BufferLoad and BufferStore nodes using the provided analyzer. The implementation properly uses CopyOnWrite() and maps indices through the analyzer's Simplify method.
Lines 308-343: VectorizeAtomicAdd implementation looks correct. The function properly:
- Determines the maximum vectorization size based on compute capability and data type.
- Uses PartitionLoop to simplify the loop before planning.
- Plans the vectorization with the planner.
- Returns the original loop if vectorization is not beneficial (vectorize_hint == 1).
- Applies the vectorization rewrite with the provided analyzer and loop layout.
A minimal sketch of this flow appears after the list.
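As an illustration of the steps above, a hedged sketch under assumed signatures (the plan-result fields, the rewriter name, and the parameters are stand-ins, not the exact API in src/transform/atomicadd_vectorize.cc):

```cpp
// Hedged sketch of the VectorizeAtomicAdd flow; type and function
// names below are assumptions for illustration only.
For VectorizeAtomicAddSketch(const For &loop, const Var &thread_var,
                             int compute_capability,
                             const Fragment &loop_layout,
                             arith::Analyzer *analyzer) {
  // Simplify/partition the loop before asking the planner.
  For simplified = PartitionLoop(loop, thread_var, analyzer, loop_layout);
  // Let the planner pick a vector width from arch + dtypes.
  AtomicAddVectorizePlanResult plan =
      AtomicAddVectorizePlanner().Plan(simplified, compute_capability);
  // Not beneficial: return the original loop untouched.
  if (plan.vector_size <= 1)
    return loop;
  // Rewrite the loop body with the planned width.
  return AtomicAddVectorizeRewriter(plan).Rewrite(simplified);
}
```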
Lines 174-231: Verify vectorization across all loop levels and add tests. The run() method divides every transformed loop extent by vector_size_, not just the innermost; no existing tests cover this behavior, so please confirm this is intentional and add multi-level loop vectorization test cases.

src/transform/atomicadd_vectorize.h (1)
Lines 17-19: Approve VectorizeAtomicAdd signature update. All call sites, including src/op/atomic_add.cc, have been updated to match the new parameters.
I refactored part of the AtomicAdd auto-vectorization code.
Actionable comments posted: 3
📜 Review details: CodeRabbit UI, profile CHILL, plan Pro.
📒 Files selected for processing (4): src/op/atomic_add.cc (1 hunk), src/transform/atomicadd_vectorize.cc (6 hunks), src/transform/atomicadd_vectorize.h (1 hunk), tilelang/language/atomic.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1): src/transform/atomicadd_vectorize.h
```cpp
const BufferLoad dst_node =
    Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
const BufferLoad value_node =
    Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>());
```
Downcast expects an ObjectRef, not a raw node pointer. Downcast&lt;BufferLoad&gt;(node->args[1].as&lt;BufferLoadNode&gt;()) doesn’t compile; the overload requires an ObjectRef, but as&lt;...&gt;() returns const BufferLoadNode*. Use the original PrimExpr instead (we already checked it’s a BufferLoad), e.g.:
```diff
-      const BufferLoad dst_node =
-          Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
-      const BufferLoad value_node =
-          Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>());
+      const BufferLoad dst_node = Downcast<BufferLoad>(node->args[1]);
+      const BufferLoad value_node = Downcast<BufferLoad>(node->args[2]);
```
Without this change the file fails to build.
tilelang/language/atomic.py (outdated)
```python
src_extent = list(get_extent(value))
dst_extent = list(get_extent(dst))
legal = True

if (dst_extent is None and src_extent is None) or len(dst_extent) < len(src_extent):
    legal = False
elif (dst_extent and src_extent):
    if len(dst_extent) > len(src_extent):
        dst_extent_dims = [x for x in dst_extent if x != 1]
        if dst_extent_dims != src_extent:
            legal = False
    else:
        if dst_extent != src_extent:
            legal = False
else:
    src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
    dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
    extent = max(dst_extent, src_extent)
    dst_extent = src_extent = extent
```
Guard get_extent() results before wrapping in list(...). get_extent() still returns None for scalar PrimExpr inputs (e.g., atomic_add(dst, 1)), so list(get_extent(...)) raises a TypeError before we can fall back to the extern path. This regresses the scalar code path. Please keep the raw result, check for None, and only convert to list when defined, before the length/shape logic.
```diff
-    src_extent = list(get_extent(value))
-    dst_extent = list(get_extent(dst))
+    src_extent_raw = get_extent(value)
+    dst_extent_raw = get_extent(dst)
+    src_extent = list(src_extent_raw) if src_extent_raw is not None else None
+    dst_extent = list(dst_extent_raw) if dst_extent_raw is not None else None
     legal = True
-    if (dst_extent is None and src_extent is None) or len(dst_extent) < len(src_extent):
+    if dst_extent is None and src_extent is None:
+        legal = False
+    elif dst_extent is None:
+        dst_extent = [1] * len(src_extent)
+    elif src_extent is None:
+        src_extent = [1] * len(dst_extent)
+    elif len(dst_extent) < len(src_extent):
         legal = False
```
Make sure the remaining branches avoid len(None) as well.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/transform/atomicadd_vectorize.cc (1)
Lines 52-205: Fix AtomicAdd rewrites: size guard and correct Downcast usage. Two problems here break compilation and correctness:
- We only check args.size() >= 2, yet we always read args[2]; AtomicAdd calls must therefore guard for >= 3 before touching the value operand.
- Downcast&lt;BufferLoad&gt;(node->args[1].as&lt;BufferLoadNode&gt;()) is invalid; the overload expects an ObjectRef, not a raw node pointer. This reintroduces the build failure from the previous review.

Please tighten the guards and pass the original PrimExpr to Downcast:

```diff
-    if (node->op == builtin::call_extern() && node->args.size() >= 2) {
+    if (node->op == builtin::call_extern() && node->args.size() >= 3) {
       if (const auto *func_name = node->args[0].as<StringImmNode>()) {
         if (func_name->value == "AtomicAdd") {
           const BufferLoadNode *temp_dst_node =
               node->args[1].as<BufferLoadNode>();
           const BufferLoadNode *temp_value_node =
               node->args[2].as<BufferLoadNode>();
           if (!temp_dst_node || !temp_value_node) {
             return StmtExprMutator::VisitExpr_(node);
           }
-          const BufferLoad dst_node =
-              Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
-          const BufferLoad value_node =
-              Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>());
+          const BufferLoad dst_node = Downcast<BufferLoad>(node->args[1]);
+          const BufferLoad value_node = Downcast<BufferLoad>(node->args[2]);
```
📜 Review details: CodeRabbit UI, profile CHILL, plan Pro.
📒 Files selected for processing (3): src/op/atomic_add.cc (2 hunks), src/transform/atomicadd_vectorize.cc (6 hunks), src/transform/atomicadd_vectorize.h (1 hunk)
Actionable comments posted: 4
📜 Review details: CodeRabbit UI, profile CHILL, plan Pro.
📒 Files selected for processing (2): src/op/atomic_add.cc (3 hunks), src/transform/atomicadd_vectorize.cc (6 hunks)
```cpp
auto GetArchInt = [&](const Target &tgt) -> int {
  int arch_int = 0;
  if (auto s = tgt->GetAttr<String>("arch")) {
    std::string arch = s.value();
    if (arch.rfind("sm_", 0) == 0)
      arch_int = std::stoi(arch.substr(3));
  }
  return arch_int;
};
```
🛠️ Refactor suggestion | 🟠 Major
Drop the shadowing GetArchInt lambda. We already have a file-scope GetArchInt(Target) (lines 37-48). Redefining an identical lambda here is redundant, risks divergence, and silently bypasses the shared helper (e.g., any future validation fixes). Call the existing function instead of shadowing it locally.
Actionable comments posted: 1
♻️ Duplicate comments (3)
src/op/atomic_add.cc (2)
Lines 375-383: Drop the shadowing GetArchInt lambda. This lambda duplicates the file-scope GetArchInt(Target) function from src/target/utils.h (already included on line 13). Using a local shadow risks divergence and bypasses any future validation fixes in the shared implementation. Based on past review comments. Apply this diff:

```diff
-  auto GetArchInt = [&](const Target &tgt) -> int {
-    int arch_int = 0;
-    if (auto s = tgt->GetAttr<String>("arch")) {
-      std::string arch = s.value();
-      if (arch.rfind("sm_", 0) == 0)
-        arch_int = std::stoi(arch.substr(3));
-    }
-    return arch_int;
-  };
```

And update lines 455 and 506 to call GetArchInt(target) directly.
Lines 490-493: Propagate the dynamic predicate. The planner's dynamic predicate is captured here but never used. When plan.dynamic is true and plan.condition is defined, the final vectorized_thread_loop should be wrapped in a guard (e.g., IfThenElse(pred, vectorized_thread_loop, thread_loop)) so the vectorized path executes only when the condition holds. Based on past review comments.

If you intend to propagate the predicate, apply a diff similar to:

```diff
   auto vectorized_thread_loop =
       VectorizeAtomicAdd(thread_loop, GetArchInt(target));
+  if (ret.predicate.defined()) {
+    return IfThenElse(ret.predicate.value(), vectorized_thread_loop, thread_loop);
+  }
   return vectorized_thread_loop;
```

src/transform/atomicadd_vectorize.cc (1)
Lines 227-230: Pass the original PrimExprs to Downcast. Downcast&lt;BufferLoad&gt;(node->args[1].as&lt;BufferLoadNode&gt;()) attempts to downcast a raw pointer, which fails to compile. Since lines 220-224 already verify the args are BufferLoadNode*, pass the original PrimExpr objects directly to Downcast. Based on past review comments. Apply this diff:

```diff
-      const BufferLoad dst_node =
-          Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
-      const BufferLoad value_node =
-          Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>());
+      const BufferLoad dst_node = Downcast<BufferLoad>(node->args[1]);
+      const BufferLoad value_node = Downcast<BufferLoad>(node->args[2]);
```
📜 Review details: CodeRabbit UI, profile CHILL, plan Pro.
📒 Files selected for processing (3): src/op/atomic_add.cc (4 hunks), src/transform/atomicadd_vectorize.cc (6 hunks), src/transform/atomicadd_vectorize.h (1 hunk)
🔇 Additional comments (7)
src/op/atomic_add.cc (2)
Lines 292-331: LGTM! The lazy ParallelOp construction and layout validation logic is well-structured. The diagnostic output in the ICHECK will help users understand layout mismatches.
Line 59: Verify Python AtomicAdd arg ordering. No C++ callers require updates; confirm Python wrappers (e.g., tvm/tir/op/atomic_add.py) and tests now use args[2] for coalesced_width instead of args[3].

src/transform/atomicadd_vectorize.cc (5)
Lines 17-70: LGTM! The planner's PostOrderVisit approach cleanly separates the max vector size determination phase from the buffer access analysis phase. The early exit when vectorize_size_max <= 1 avoids unnecessary work.
Lines 72-97: LGTM! The visitor pattern correctly tracks the innermost loop and processes AtomicAdd buffer accesses to refine the vector size.
Lines 99-111: LGTM! The compute capability and dtype-based vector size limits correctly reflect hardware constraints for atomic operations.
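For illustration only: the sole threshold confirmed in this thread is that 4-wide float32 atomics require cc >= 90, so a limit helper might be shaped like the following (the non-float32 entries are placeholders, not the real table):

```cpp
#include <tvm/runtime/data_type.h>

using tvm::runtime::DataType;

// Hedged sketch of a capability/dtype width table. Only the
// "float32 x4 needs cc >= 90" rule is confirmed in this thread;
// the other entries are placeholders.
int GetVectorizeSizeMaxSketch(int compute_capability, DataType dtype) {
  if (dtype == DataType::Float(32)) {
    // AtomicAddx4 for float32 is only selected on cc >= 90 hardware.
    return compute_capability >= 90 ? 4 : 1;
  }
  if (dtype == DataType::Float(16)) {
    return 2;  // placeholder: paired half-precision atomic adds
  }
  return 1;  // default: scalar atomicAdd
}
```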
Lines 113-153: LGTM! The vector size refinement logic correctly uses GCD operations to find alignment constraints and the IndiceCanVectorize check to verify vectorization safety. The dynamic fallback path properly captures the runtime predicate.
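The GCD idea can be shown with plain integers standing in for the provably-constant parts of the index (a sketch of the principle only; the real UpdateVectorSize reasons over TIR expressions):

```cpp
#include <cstdint>
#include <numeric>

// Hedged sketch: shrink the candidate width until it divides both the
// constant stride and offset of the access, so every lane of a vector
// atomic stays aligned. Widths are powers of two, so halving walks
// the candidate divisors.
int RefineVectorSize(int vector_size_max, int64_t stride, int64_t offset) {
  int vec = vector_size_max;
  while (vec > 1 &&
         (std::gcd(stride, static_cast<int64_t>(vec)) != vec ||
          std::gcd(offset, static_cast<int64_t>(vec)) != vec)) {
    vec /= 2;
  }
  return vec;
}
```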
Lines 262-271: LGTM! The entry point cleanly orchestrates the planner-driven vectorization flow. The early exit avoids unnecessary rewriting when vectorization isn't beneficial.
Actionable comments posted: 0
🧹 Nitpick comments (1)
src/transform/atomicadd_vectorize.h (1)
Lines 48-50: Remove unused planner state. has_nonlocal_memory_access_ is introduced here but never referenced in atomicadd_vectorize.cc, so it will trip -Wunused-private-field (and become a hard error under -Werror). Please drop it or wire it into the implementation.
📜 Review details: CodeRabbit UI, profile CHILL, plan Pro.
📒 Files selected for processing (1): src/transform/atomicadd_vectorize.h (1 hunk)
Actionable comments posted: 0
♻️ Duplicate comments (2)
src/op/atomic_add.cc (2)
Lines 372-380: Drop the shadowing GetArchInt lambda. A file-scope GetArchInt(Target) already exists in src/target/utils.cc. Redefining an identical lambda here is redundant, risks divergence if the shared helper is updated with validation fixes, and bypasses the centralized implementation. Replace the lambda with a call to the existing file-scope GetArchInt:

```diff
-  auto GetArchInt = [&](const Target &tgt) -> int {
-    int arch_int = 0;
-    if (auto s = tgt->GetAttr<String>("arch")) {
-      std::string arch = s.value();
-      if (arch.rfind("sm_", 0) == 0)
-        arch_int = std::stoi(arch.substr(3));
-    }
-    return arch_int;
-  };
```

And update the remaining call sites to call GetArchInt(target) directly.

Lines 487-504: Propagate the planner predicate for dynamic vectorization. AtomicAddVectorizePlanner::Plan can return dynamic=true with a condition guard. The predicate is captured but never used; when dynamic is true, the vectorized body must be wrapped with the guard before execution, otherwise the dynamic plan executes incorrectly. Apply this diff to wrap the vectorized loop with the predicate when present:

```diff
   auto ret = AtomicAddInferLayout(transformed_loop,
                                   {T.target, T.thread_bounds, T.layout_map,
                                    analyzer, false, T.buffer_remap});
   Fragment loop_layout = ret.loop_layout;
   auto thread_loop =
       PartitionLoop(transformed_loop, T.thread_var, analyzer, loop_layout);
   auto vectorized_thread_loop =
       VectorizeAtomicAdd(thread_loop, GetArchInt(target));
-  return vectorized_thread_loop;
+  if (ret.predicate.defined()) {
+    return IfThenElse(ret.predicate.value(), vectorized_thread_loop, thread_loop);
+  }
+  return vectorized_thread_loop;
 }
```
🧹 Nitpick comments (1)
src/op/atomic_add.cc (1)
398-409
: Consider adding consistency checks for buffer indices. Unlike ParallelLoopNestVisitor (see src/op/parallel.cc lines 147-172), AtomicLoopNestCollector does not verify that the same buffer is accessed with structurally equal indices across multiple access sites. If a buffer has inconsistent access patterns, the collector silently overwrites the indices, potentially masking layout inference bugs. Apply this diff to add consistency validation:

```diff
   void VisitStmt_(const BufferStoreNode *op) final {
     if (op->buffer.scope() == "local.fragment") {
+      if (indice_map.count(op->buffer)) {
+        ICHECK(StructuralEqual()(indice_map[op->buffer], op->indices))
+            << op->buffer << ": " << op->indices << " vs "
+            << indice_map[op->buffer];
+      }
       indice_map.Set(op->buffer, op->indices);
       writes.insert(op->buffer);
     }
     StmtExprVisitor::VisitStmt_(op);
   }

   void VisitExpr_(const BufferLoadNode *op) final {
     if (op->buffer.scope() == "local.fragment") {
+      if (indice_map.count(op->buffer)) {
+        ICHECK(StructuralEqual()(indice_map[op->buffer], op->indices))
+            << op->buffer << ": " << op->indices << " vs "
+            << indice_map[op->buffer];
+      }
       indice_map.Set(op->buffer, op->indices);
     }
     StmtExprVisitor::VisitExpr_(op);
   }
```
📜 Review details: CodeRabbit UI, profile CHILL, plan Pro.
📒 Files selected for processing (1): src/op/atomic_add.cc (3 hunks)
🔇 Additional comments (3)
src/op/atomic_add.cc (3)
Line 16: LGTM! The include for loop_parallel_transform_utils.h is correctly added to support the new ParallelLoopTransformer::Substitute call introduced in the refactored Lower method.
Lines 289-328: LGTM! The InferLayout method correctly implements lazy initialization of the ParallelOp, validates fragment layout consistency between source and destination buffers, and properly delegates to the underlying ParallelOp for layout inference.
Lines 413-431: LGTM! The ComputeLoopLayoutFromBuffer lambda correctly derives a fragment layout by computing the forward thread expression from buffer indices and binding to the thread range. The AtomicInferResult struct appropriately encapsulates the layout and optional predicate results from the inference process.
Actionable comments posted: 0
♻️ Duplicate comments (2)
src/op/atomic_add.cc (2)
Lines 378-386: Drop the redundant local GetArchInt lambda. This local lambda duplicates the file-scope GetArchInt(Target) function already available in ../target/utils.h (included at line 13). The duplication creates maintenance risk and bypasses any future improvements to the shared implementation. Apply this diff to remove the lambda and use the file-scope function:

```diff
-  auto GetArchInt = [&](const Target &tgt) -> int {
-    int arch_int = 0;
-    if (auto s = tgt->GetAttr<String>("arch")) {
-      std::string arch = s.value();
-      if (arch.rfind("sm_", 0) == 0)
-        arch_int = std::stoi(arch.substr(3));
-    }
-    return arch_int;
-  };
```

Then update the call sites to use the file-scope GetArchInt(target) directly.

Lines 502-510: Propagate the planner predicate for dynamic vectorization. AtomicAddInferLayout captures a predicate from the planner when dynamic=true, but the predicate is never used. When dynamic vectorization is planned, the vectorized body should be wrapped with the predicate guard; otherwise the dynamic plan executes incorrectly. Apply this diff to wrap the vectorized loop with the predicate when present:

```diff
   auto vectorized_thread_loop =
       VectorizeAtomicAdd(thread_loop, GetArchInt(target));
-  return vectorized_thread_loop;
+  if (ret.predicate.defined()) {
+    return IfThenElse(ret.predicate.value(), vectorized_thread_loop);
+  } else {
+    return vectorized_thread_loop;
+  }
 }
```
🧹 Nitpick comments (1)
src/op/atomic_add.cc (1)
388-417
: Add index consistency validation in AtomicLoopNestCollector. The collector sets indice_map without validating that repeated accesses to the same buffer use structurally equal indices. This differs from ParallelLoopNestVisitor (in src/op/parallel.cc lines 147-172), which validates consistency with ICHECK(StructuralEqual()(...)). Silently overwriting indices can hide bugs where the same buffer is accessed with different patterns. Apply this diff to add validation:

```diff
   void VisitStmt_(const BufferStoreNode *op) final {
     if (op->buffer.scope() == "local.fragment") {
+      if (indice_map.count(op->buffer)) {
+        ICHECK(StructuralEqual()(indice_map[op->buffer], op->indices))
+            << op->buffer << ": " << op->indices << " and "
+            << indice_map[op->buffer];
+      }
       indice_map.Set(op->buffer, op->indices);
       writes.insert(op->buffer);
     }
     StmtExprVisitor::VisitStmt_(op);
   }

   void VisitExpr_(const BufferLoadNode *op) final {
     if (op->buffer.scope() == "local.fragment") {
+      if (indice_map.count(op->buffer)) {
+        ICHECK(StructuralEqual()(indice_map[op->buffer], op->indices))
+            << op->buffer << ": " << op->indices << " and "
+            << indice_map[op->buffer];
+      }
       indice_map.Set(op->buffer, op->indices);
     }
     StmtExprVisitor::VisitExpr_(op);
   }
```
📜 Review details: CodeRabbit UI, profile CHILL, plan Pro.
📒 Files selected for processing (1): src/op/atomic_add.cc (3 hunks)
🔇 Additional comments (5)
src/op/atomic_add.cc (5)
Line 16: LGTM! The new include is necessary for ParallelLoopTransformer used in the refactored lowering pipeline.
Lines 295-334: LGTM! The InferLayout implementation properly:
- Lazily constructs the ParallelOp when needed
- Validates fragment layout consistency between src and dst
- Delegates to the underlying ParallelOp for layout inference
Lines 375-376: LGTM! The ParallelLoopTransformer::Substitute call appropriately transforms the fused loop before further processing.
Lines 419-432: LGTM! The ComputeLoopLayoutFromBuffer lambda correctly computes a Fragment layout from buffer indices and thread bounds.
Lines 439-500: The layout inference logic is sound, but predicate handling needs attention. The AtomicAddInferLayout implementation correctly:
- Collects loop nest information
- Selects the best buffer for layout inference
- Uses the planner to determine vector size
- Adjusts vector size based on the coalesced_width annotation and divisibility constraints
- Computes the final loop layout
However, the predicate captured at lines 493-496 is returned but not used by the caller (see the next comment on lines 502-510).
Actionable comments posted: 0
♻️ Duplicate comments (2)
src/op/atomic_add.cc (2)
Lines 412-420: Remove the redundant GetArchInt lambda. A file-scope GetArchInt(Target) function already exists (imported via #include "../target/utils.h" at line 13, defined in src/target/utils.cc). This local lambda duplicates that logic and shadows the shared implementation. Use the existing function directly instead. Apply this diff:

```diff
-  auto GetArchInt = [&](const Target &tgt) -> int {
-    int arch_int = 0;
-    if (auto s = tgt->GetAttr<String>("arch")) {
-      std::string arch = s.value();
-      if (arch.rfind("sm_", 0) == 0)
-        arch_int = std::stoi(arch.substr(3));
-    }
-    return arch_int;
-  };
```

Then update lines 492 and 543 to call the file-scope function:

```diff
-  int sm = GetArchInt(target);
+  int sm = tl::GetArchInt(target);
```

```diff
-      VectorizeAtomicAdd(thread_loop, GetArchInt(target));
+      VectorizeAtomicAdd(thread_loop, tl::GetArchInt(target));
```
Lines 536-544: Missing predicate guard for dynamic vectorization. The planner's predicate (captured at line 533) is never used. When plan.dynamic is true, the vectorized body must be wrapped with the predicate guard to ensure correctness. Currently, line 544 returns the vectorized loop unconditionally, which can produce incorrect results for dynamic plans. Apply this diff to wrap the vectorized loop when a predicate exists:

```diff
   auto ret = AtomicAddInferLayout(transformed_loop,
                                   {T.target, T.thread_bounds, T.layout_map,
                                    analyzer, false, T.buffer_remap});
   Fragment loop_layout = ret.loop_layout;
   auto thread_loop =
       PartitionLoop(transformed_loop, T.thread_var, analyzer, loop_layout);
   auto vectorized_thread_loop =
       VectorizeAtomicAdd(thread_loop, tl::GetArchInt(target));
-  return vectorized_thread_loop;
+  if (ret.predicate.defined()) {
+    return IfThenElse(ret.predicate.value(), vectorized_thread_loop);
+  }
+  return vectorized_thread_loop;
 }
```
🧹 Nitpick comments (2)
src/op/atomic_add.cc (2)
Lines 477-490: Consider renaming read_src for clarity. The variable read_src (line 477) doesn't verify whether the buffer is read-only; it simply selects the buffer with the highest rank from indice_map. Consider renaming it to layout_src or infer_src to better reflect its purpose.
Lines 523-524: Remove unnecessary remapped variable. Line 523 assigns loop to remapped without any transformation. This appears to be leftover code from a refactoring. Simplify by using loop directly at line 524. Apply this diff:

```diff
-  const For &remapped = loop;
-  loop_layout = PlanLoopPartition(remapped, vec, args.thread_bounds);
+  loop_layout = PlanLoopPartition(loop, vec, args.thread_bounds);
```
📜 Review details: CodeRabbit UI, profile CHILL, plan Pro.
📒 Files selected for processing (1): src/op/atomic_add.cc (3 hunks)
🔇 Additional comments (3)
src/op/atomic_add.cc (3)
Lines 307-346: LGTM: Well-structured layout inference method. The new InferLayout method follows good patterns:
- Lazy initialization of the ParallelOp cache
- Proper validation of fragment layout consistency between src and dst
- Clean delegation to the underlying ParallelOp
Lines 422-451: LGTM: Clean visitor implementation. AtomicLoopNestCollector follows the established visitor pattern from ParallelLoopNestVisitor in src/op/parallel.cc, appropriately simplified for atomic add operations.
Lines 453-466: LGTM: Correct layout computation. The ComputeLoopLayoutFromBuffer lambda correctly derives loop layout from buffer layout using ForwardThread and thread range binding, consistent with the pattern in src/op/parallel.cc.
Summary by CodeRabbit: New Features, Refactor, Breaking Changes.